#!/usr/bin/env python3
"""
LLM Values Analysis Tool

This script processes and analyzes data from different prompting strategies used to elicit
value responses from Large Language Models. It handles data preparation, value calculations,
and statistical analysis, with results saved to files.

Supported models: GPT, Gemini, Llama 8B, Llama 70B, Gemma 9B, Gemma 27B
Supported prompt types: basic, bwvr, demographic, persona, names
Supported prompting methods: batch, serial

Usage:
    python llm_values_analysis.py --prompt_type PROMPT_TYPE [--model MODEL] [--temperature TEMP]
                                  [--prompting_method METHOD] [--compare_methods] [--output_dir DIR]

Arguments:
    --prompt_type       (required) Type of prompt strategy to analyze: basic, bwvr, demographic, persona, names
    --model             LLM model to analyze (gpt, gemini, llama8b, llama70b, gemma9b, gemma27b, or all), defaults to all
    --temperature       Temperature setting to analyze (0.0, 0.7, or all), defaults to all
    --prompting_method  Prompting method to analyze (batch, serial, or both), defaults to both
    --compare_methods   Compare batch vs. serial prompting for the same model(s)
    --output_dir        Directory to save output files, defaults to "results"; results for each
                        prompt type are written to the subdirectory <output_dir>/<prompt_type>/
"""

import os
import sys
import argparse
import pandas as pd
import numpy as np
from scipy import stats
import math

# Constants for value calculations.
# Each of the 19 values maps to the three PVQ item labels that are averaged to
# score it in prepare_dataset. The labels must match the entries of the input
# CSV's 'Number' column, because those entries become column names after pivoting.
# NOTE: insertion order sets the order of the value columns in the prepared CSVs.
VALUE_GROUPS = {
    'Self-direction Thought': ['PVQ1', 'PVQ23', 'PVQ39'],
    'Self-direction Action': ['PVQ16', 'PVQ30', 'PVQ56'],
    'Stimulation': ['PVQ10', 'PVQ28', 'PVQ43'],
    'Hedonism': ['PVQ3', 'PVQ36', 'PVQ46'],
    'Achievement': ['PVQ17', 'PVQ32', 'PVQ48'],
    'Power Dominance': ['PVQ6', 'PVQ29', 'PVQ41'],
    'Power Resources': ['PVQ12', 'PVQ20', 'PVQ44'],
    'Face': ['PVQ9', 'PVQ24', 'PVQ49'],
    'Security Personal': ['PVQ13', 'PVQ26', 'PVQ53'],
    'Security Societal': ['PVQ2', 'PVQ35', 'PVQ50'],
    'Tradition': ['PVQ18', 'PVQ33', 'PVQ40'],
    'Conformity-Rules': ['PVQ15', 'PVQ31', 'PVQ42'],
    'Conformity-Interpersonal': ['PVQ4', 'PVQ22', 'PVQ51'],
    'Humility': ['PVQ7', 'PVQ38', 'PVQ54'],
    'Universalism-Nature': ['PVQ8', 'PVQ21', 'PVQ45'],
    'Universalism-Concern': ['PVQ5', 'PVQ37', 'PVQ52'],
    'Universalism-Tolerance': ['PVQ14', 'PVQ34', 'PVQ57'],
    'Benevolence-Care': ['PVQ11', 'PVQ25', 'PVQ47'],
    'Benevolence-Dependability': ['PVQ19', 'PVQ27', 'PVQ55']
}

# Benchmark values from Schwartz & Cieciuch (2022), 50th percentile.
# Keys use the 'c_<value>' names of the centered columns built in
# prepare_dataset, so calculate_rankings can merge on them directly.
# Entries are listed from highest to lowest benchmark score.
BENCHMARK_VALUES = {
    'c_Benevolence-Care': 0.794,
    'c_Benevolence-Dependability': 0.726,
    'c_Self-direction Action': 0.597,
    'c_Self-direction Thought': 0.582,
    'c_Universalism-Concern': 0.502,
    'c_Universalism-Tolerance': 0.37,
    'c_Security Societal': 0.322,
    'c_Security Personal': 0.281,
    'c_Hedonism': 0.228,
    'c_Achievement': 0.078,
    'c_Face': 0.047,
    'c_Universalism-Nature': -0.105,
    'c_Stimulation': -0.11,
    'c_Conformity-Interpersonal': -0.162,
    'c_Humility': -0.205,
    'c_Conformity-Rules': -0.257,
    'c_Tradition': -0.719,
    'c_Power Resources': -1.332,
    'c_Power Dominance': -1.403
}

# Map of model names and their simplified versions for file names.
# This is currently an identity mapping. Within this file only the keys are
# read: find_input_files expands model 'all' to every other key.
MODEL_NAMES = {
    'gpt': 'gpt',
    'gemini': 'gemini',
    'llama8b': 'llama8b',
    'llama70b': 'llama70b',
    'gemma9b': 'gemma9b',
    'gemma27b': 'gemma27b',
    'all': 'all'
}

def parse_arguments():
    """
    Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``prompt_type``,
        ``model``, ``temperature``, ``prompting_method``, ``compare_methods``
        and ``output_dir``.
    """
    cli = argparse.ArgumentParser(description='Process and analyze LLM value data.')

    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = [
        ('--prompt_type', dict(
            type=str, required=True,
            choices=['basic', 'bwvr', 'demographic', 'persona', 'names'],
            help='Type of prompt strategy to analyze')),
        ('--model', dict(
            type=str, default='all',
            choices=['gpt', 'gemini', 'llama8b', 'llama70b', 'gemma9b', 'gemma27b', 'all'],
            help='LLM model to analyze')),
        ('--temperature', dict(
            type=str, default='all',
            choices=['0.0', '0.7', 'all'],
            help='Temperature setting to analyze (0.0, 0.7, or "all")')),
        ('--prompting_method', dict(
            type=str, default='both',
            choices=['batch', 'serial', 'both'],
            help='Method used for prompting: batch or serial')),
        ('--compare_methods', dict(
            action='store_true',
            help='Compare batch vs. serial prompting for the same model')),
        ('--output_dir', dict(
            type=str, default='results',
            help='Directory to save output files')),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def ensure_dir(directory):
    """
    Create ``directory`` (including any missing parents) if it does not exist yet.

    A message is printed only when the directory is actually created.

    Args:
        directory (str): Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
        OSError: For any other failure while creating the directory.
    """
    try:
        # Attempt creation directly instead of check-then-create. That way a
        # directory created concurrently between a check and makedirs is not an error.
        os.makedirs(directory)
    except FileExistsError:
        if not os.path.isdir(directory):
            # A regular file occupies the path. Fail here rather than later,
            # when output files are written into it.
            raise
        return
    print(f"Created directory: {directory}")

def find_input_files(prompt_type, model, temperature, prompting_method, input_dir=''):
    """
    Find existing input files matching the specified criteria.

    Candidate names follow ``output_{model}_{prompt_type}_{method}_{temp}.csv``,
    where ``temp`` is the temperature with its decimal point removed
    (``'0.7'`` -> ``'07'``). Only candidates that exist as regular files are
    returned, so an empty list means no data is available for the request.
    Each missing candidate is reported on stdout.

    Args:
        prompt_type (str): Type of prompt strategy
        model (str): Model name, or 'all' for every non-'all' key of MODEL_NAMES
        temperature (str): Temperature setting ('0.0', '0.7') or 'all'
        prompting_method (str): Prompting method ('batch', 'serial') or 'both'
        input_dir (str): Directory containing the input files. The default ''
            resolves names relative to the current working directory.

    Returns:
        list: Paths of the matching files that exist, in model / temperature /
        method order
    """
    if model == 'all':
        models = [name for name in MODEL_NAMES if name != 'all']
    else:
        models = [model]

    if temperature == 'all':
        temperatures = ['00', '07']  # Representing 0.0 and 0.7
    else:
        # Convert temperature like "0.7" to "07"
        temperatures = [temperature.replace('.', '')]

    methods = ['batch', 'serial'] if prompting_method == 'both' else [prompting_method]

    files = []
    for m in models:
        for t in temperatures:
            for method in methods:
                file_name = f"output_{m}_{prompt_type}_{method}_{t}.csv"
                path = os.path.join(input_dir, file_name)
                # Only return data that can actually be read. Otherwise the
                # caller's "no input files" check could never trigger.
                if os.path.isfile(path):
                    files.append(path)
                else:
                    print(f"Skipping missing input file: {path}")

    return files

def prepare_dataset(input_file, output_dir):
    """
    Turn a long-format ratings CSV into one row per (Date, Iteration) with value scores.

    The input is pivoted so that each distinct ``Number`` becomes a column holding
    its ``Rating``. Then:

    - each VALUE_GROUPS entry is scored as the mean of whichever of its items are present;
    - ``MRAT`` is the mean over every column whose name starts with ``PVQ``;
    - each scored value gets a ``c_<value>`` column equal to its score minus ``MRAT``.

    The result is written to ``<output_dir>/<input file stem>_prepared.csv``.

    Args:
        input_file (str): Path to an ``output_{model}_{prompt}_{method}_{temp}.csv``
            file with ``Date``, ``Iteration``, ``Number`` and ``Rating`` columns.
        output_dir (str): Directory that receives the prepared CSV.

    Returns:
        tuple: ``(prepared DataFrame, metadata dict)`` on success, or
        ``(None, None)`` if any step raises (the error is printed).
    """
    print(f"Preparing dataset: {input_file}")
    try:
        raw = pd.read_csv(input_file)

        # The file name encodes the run configuration, e.g.
        # 'output_gpt_basic_batch_07.csv' -> gpt / basic / batch / 07.
        name_parts = os.path.basename(input_file).split('_')
        metadata = dict(
            model=name_parts[1],
            prompt_type=name_parts[2],
            prompting_method=name_parts[3],
            temperature=name_parts[4].split('.')[0],
        )

        # One row per (Date, Iteration); 'Date' and 'Iteration' become ordinary columns again.
        wide = raw.pivot(index=['Date', 'Iteration'], columns='Number', values='Rating').reset_index()

        # Raw value scores: average only the items that actually appear in this file.
        for value_name, items in VALUE_GROUPS.items():
            available = [item for item in items if item in wide.columns]
            if available:
                wide[value_name] = wide[available].mean(axis=1)

        # MRAT: the respondent's mean rating over all PVQ items.
        item_columns = [name for name in wide.columns if name.startswith('PVQ')]
        wide['MRAT'] = wide[item_columns].mean(axis=1)

        # Centered scores remove each row's overall rating tendency.
        for value_name in VALUE_GROUPS:
            if value_name in wide.columns:
                wide[f'c_{value_name}'] = wide[value_name] - wide['MRAT']

        stem = os.path.splitext(os.path.basename(input_file))[0]
        destination = os.path.join(output_dir, f"{stem}_prepared.csv")
        wide.to_csv(destination, index=False)
        print(f"Prepared dataset saved to: {destination}")
    except Exception as err:
        print(f"Error preparing dataset {input_file}: {err!s}")
        return None, None

    return wide, metadata

def calculate_rankings(df, output_dir, file_prefix, metadata):
    """
    Rank a dataset's mean centered values and correlate them with the benchmark.

    Files written into ``output_dir``:

    - ``<file_prefix>_rankings.csv``: always, once the ranking succeeds;
    - ``<file_prefix>_correlation.csv`` and ``<file_prefix>_summary.txt``: only
      when at least one centered column matches a BENCHMARK_VALUES key.

    Args:
        df (pandas.DataFrame): Prepared dataframe whose ``c_``-prefixed columns
            hold centered values
        output_dir (str): Directory to save results
        file_prefix (str): Prefix for output files
        metadata (dict): Metadata about the dataset. It is copied into every
            ranking row, and its 'model', 'prompt_type', 'prompting_method' and
            'temperature' entries go into the summary.

    Returns:
        tuple: ``(correlation, ranked_df)``. ``correlation`` is the Spearman rho
        between benchmark and model means, or None when nothing matched. Both
        are None if any step raised (the error is printed).
    """
    print(f"Calculating rankings for {file_prefix}")
    try:
        centered_means = df[[name for name in df.columns if name.startswith('c_')]].mean()

        ranked_df = (
            pd.DataFrame({'Column': centered_means.index, 'Centered Mean': centered_means.values})
            .sort_values(by='Centered Mean', ascending=False)
            .reset_index(drop=True)
        )
        for meta_key, meta_value in metadata.items():
            ranked_df[meta_key] = meta_value

        rankings_path = os.path.join(output_dir, f"{file_prefix}_rankings.csv")
        ranked_df.to_csv(rankings_path, index=False)
        print(f"Rankings saved to: {rankings_path}")

        # Pair each benchmark value with the model's mean for the same 'c_' column.
        benchmark_df = pd.DataFrame(list(BENCHMARK_VALUES.items()), columns=['Column', 'Benchmark'])
        matched = benchmark_df.merge(ranked_df, on='Column', how='inner')

        if matched.empty:
            print("Warning: No matching columns found between data and benchmark.")
            return None, ranked_df

        rho, p_value = stats.spearmanr(matched['Benchmark'], matched['Centered Mean'])
        matched['Benchmark Rank'] = matched['Benchmark'].rank(ascending=False)
        matched['Model Rank'] = matched['Centered Mean'].rank(ascending=False)

        correlation_path = os.path.join(output_dir, f"{file_prefix}_correlation.csv")
        matched.to_csv(correlation_path, index=False)

        summary_path = os.path.join(output_dir, f"{file_prefix}_summary.txt")
        with open(summary_path, 'w') as handle:
            handle.write(f"Spearman Rank Correlation: {rho:.4f}\n")
            handle.write(f"P-value: {p_value:.4f}\n")
            for label, meta_key in (('Model', 'model'),
                                    ('Prompt Type', 'prompt_type'),
                                    ('Prompting Method', 'prompting_method'),
                                    ('Temperature', 'temperature')):
                handle.write(f"{label}: {metadata[meta_key]}\n")

        print(f"Correlation: {rho:.4f} (p={p_value:.4f})")
        print(f"Correlation results saved to: {correlation_path}")
        print(f"Summary saved to: {summary_path}")
        return rho, ranked_df

    except Exception as err:
        print(f"Error calculating rankings: {err!s}")
        return None, None

def visualize_value_profile(df, output_dir, file_prefix, metadata):
    """
    Save the mean centered value profile of a prepared dataset as CSV data.

    No plot is drawn. The mean of every ``c_``-prefixed column is written to
    ``<output_dir>/<file_prefix>_value_profile.csv``, sorted from highest to
    lowest score. The output has the columns ``Value`` (name without the
    ``c_`` prefix) and ``Score``, plus one constant column per metadata key.
    Errors are printed rather than raised.

    Args:
        df (pandas.DataFrame): Prepared dataframe with centered values
        output_dir (str): Directory to save data
        file_prefix (str): Prefix for output files
        metadata (dict): Metadata about the dataset
    """
    print(f"Saving value profile data for {file_prefix}")
    prefix = 'c_'
    try:
        # List of centered value columns
        centered_columns = [col for col in df.columns if col.startswith(prefix)]

        if not centered_columns:
            print("No centered value columns found.")
            return

        # Calculate mean values for centered columns
        mean_values = df[centered_columns].mean().reset_index()
        mean_values.columns = ['Value', 'Score']

        # Strip only the leading prefix (every selected column has it). A
        # global replace would also delete 'c_' occurring inside a value name.
        mean_values['Value'] = mean_values['Value'].str[len(prefix):]

        # Sort by score
        mean_values = mean_values.sort_values('Score', ascending=False)

        # Add metadata columns
        for key, value in metadata.items():
            mean_values[key] = value

        # Save the data
        profile_file = os.path.join(output_dir, f"{file_prefix}_value_profile.csv")
        mean_values.to_csv(profile_file, index=False)
        print(f"Value profile data saved to: {profile_file}")

    except Exception as e:
        print(f"Error saving value profile data: {str(e)}")

def compare_models(correlations_data, output_dir, prompt_type):
    """
    Tabulate benchmark correlations across model configurations for one prompt type.

    Entries whose correlation is None are skipped. The rest are written to
    ``<output_dir>/<prompt_type>_model_comparison.csv`` with the columns Model
    (upper-cased), Prompting_Method (capitalized), Temperature and Correlation.
    Errors are printed rather than raised.

    Args:
        correlations_data (list): List of tuples with (correlation, metadata, ranked_df)
        output_dir (str): Directory to save comparison results
        prompt_type (str): Type of prompt strategy
    """
    print(f"Comparing models for prompt type: {prompt_type}")
    try:
        rows = [
            {
                'Model': meta['model'].upper(),
                'Prompting_Method': meta['prompting_method'].capitalize(),
                'Temperature': meta['temperature'],
                'Correlation': corr,
            }
            for corr, meta, _ranked in correlations_data
            if corr is not None
        ]

        if not rows:
            print("No valid correlation data found for comparison.")
            return

        table = pd.DataFrame(rows)
        destination = os.path.join(output_dir, f"{prompt_type}_model_comparison.csv")
        table.to_csv(destination, index=False)
        print(f"Model comparison data saved to: {destination}")

    except Exception as err:
        print(f"Error comparing models: {err!s}")

def fisher_z_transform(r):
    """
    Apply Fisher's Z transformation, z = artanh(r) = 0.5 * ln((1 + r) / (1 - r)).

    The transform is infinite at |r| = 1. Inputs at or beyond +/-1 are therefore
    clamped to +/-0.9999 first, which keeps the result finite. This also covers
    values that land a hair outside [-1, 1] through floating-point round-off;
    those previously produced NaN.

    Args:
        r (float): Correlation coefficient.

    Returns:
        float: The Z-transformed value (NaN if ``r`` is NaN).
    """
    if r >= 1.0:
        r = 0.9999
    elif r <= -1.0:
        r = -0.9999

    return np.arctanh(r)

def compare_correlations(r1, r2, n, alpha=0.05):
    """
    Compare two correlation coefficients with a two-tailed Fisher Z test.

    Both coefficients are Fisher-transformed, and their difference is divided by
    the standard error sqrt(2 / (n - 3)). This is the independent-samples form
    with equal sample sizes n1 = n2 = n.

    Parameters:
    r1, r2: The two correlation coefficients to compare
    n: Sample size (must be the same for both correlations and greater than 3)
    alpha: Significance level

    Returns:
    Dictionary with the keys 'z_statistic', 'p_value', 'significant'
    (p_value < alpha), 'r1', 'r2' and 'difference' (r1 - r2).

    Raises:
    ValueError: If n <= 3, where the standard error is undefined.
    """
    # Without this guard, n == 3 divides by zero and n < 3 takes the square
    # root of a negative number.
    if n <= 3:
        raise ValueError(f"Sample size must be greater than 3 for a Fisher Z test, got {n}")

    # Apply Fisher's Z transformation
    z1 = fisher_z_transform(r1)
    z2 = fisher_z_transform(r2)

    # Calculate standard error
    se = np.sqrt(2 / (n - 3))

    # Calculate z statistic
    z_diff = (z1 - z2) / se

    # Two-tailed p-value. sf(x) equals 1 - cdf(x) but stays accurate for large
    # |z|, where 1 - cdf rounds to 0.
    p_value = 2 * stats.norm.sf(abs(z_diff))

    return {
        'z_statistic': z_diff,
        'p_value': p_value,
        'significant': p_value < alpha,
        'r1': r1,
        'r2': r2,
        'difference': r1 - r2
    }

def compare_prompting_methods(correlations_data, output_dir, model, prompt_type, sample_size=19):
    """
    Compare batch vs. serial prompting methods for the same model and prompt type.

    Each (model, temperature) configuration that has a non-None benchmark
    correlation for both methods is tested with compare_correlations. The
    results are written to
    ``<output_dir>/<model>_<prompt_type>_batch_serial_comparison.csv``, together
    with a Bonferroni-corrected significance flag, and a summary is printed.
    Errors are printed rather than raised.

    Args:
        correlations_data (list): List of tuples with (correlation, metadata, ranked_df)
        output_dir (str): Directory to save comparison results
        model (str): Model to compare or 'all' for all models
        prompt_type (str): Type of prompt strategy
        sample_size (int): Number of ranked values behind each correlation.
            Defaults to 19, the number of values in the Schwartz model.
    """
    print(f"Comparing batch vs. serial prompting for {model} with {prompt_type} prompt")

    try:
        # Organize data by (model, temperature). A tuple key keeps the parts
        # separate. Joining them with '_' and splitting again later would fail
        # for any model name that itself contains an underscore.
        comparison_data = {}
        for corr, metadata, _ in correlations_data:
            if corr is not None:
                key = (metadata['model'], metadata['temperature'])
                comparison_data.setdefault(key, {})[metadata['prompting_method']] = corr

        valid_comparisons = []
        for (model_name, temp), methods in comparison_data.items():
            # Only configurations with both batch and serial data can be compared
            if 'batch' not in methods or 'serial' not in methods:
                continue

            # Only include the specified model unless 'all' is selected
            if model != 'all' and model_name != model:
                continue

            result = compare_correlations(methods['batch'], methods['serial'], sample_size)

            # Add metadata to result
            result['model'] = model_name
            result['temperature'] = temp
            result['category'] = f"{prompt_type}_{temp}"
            result['batch_correlation'] = methods['batch']
            result['serial_correlation'] = methods['serial']

            valid_comparisons.append(result)

        if not valid_comparisons:
            print(f"No valid comparisons found for {model} with {prompt_type} prompt.")
            return

        # Create DataFrame from results
        df_results = pd.DataFrame(valid_comparisons)

        # Rearrange columns for better readability
        df_results = df_results[[
            'model', 'category', 'batch_correlation', 'serial_correlation',
            'difference', 'z_statistic', 'p_value', 'significant'
        ]]

        # Rename columns
        df_results.columns = [
            'Model', 'Category', 'Batch_Correlation', 'Serial_Correlation',
            'Difference', 'Z_Statistic', 'P_Value', 'Significant'
        ]

        # Apply Bonferroni correction for multiple comparisons
        alpha_corrected = 0.05 / len(df_results)
        df_results['Significant_Bonferroni'] = df_results['P_Value'] < alpha_corrected

        # Save to CSV
        output_file = os.path.join(output_dir, f"{model}_{prompt_type}_batch_serial_comparison.csv")
        df_results.to_csv(output_file, index=False)
        print(f"Batch vs. serial comparison saved to: {output_file}")

        # Print summary statistics
        print("\nComparison Summary:")
        print(f"Total comparisons: {len(df_results)}")
        print(f"Significant differences (α=0.05): {df_results['Significant'].sum()}")
        print(f"Significant differences after Bonferroni correction (α={alpha_corrected:.6f}): {df_results['Significant_Bonferroni'].sum()}")

        # Display where batch performed better vs where serial performed better
        batch_better = df_results[df_results['Batch_Correlation'] > df_results['Serial_Correlation']]
        serial_better = df_results[df_results['Batch_Correlation'] < df_results['Serial_Correlation']]
        equal_performance = df_results[df_results['Batch_Correlation'] == df_results['Serial_Correlation']]

        print(f"\nCategories where Batch performed better: {len(batch_better)}")
        for _, row in batch_better.iterrows():
            print(f"  {row['Model']} - {row['Category']}: Batch={row['Batch_Correlation']:.2f}, Serial={row['Serial_Correlation']:.2f}, diff={row['Difference']:.2f}, p={row['P_Value']:.4f}")

        print(f"\nCategories where Serial performed better: {len(serial_better)}")
        for _, row in serial_better.iterrows():
            print(f"  {row['Model']} - {row['Category']}: Batch={row['Batch_Correlation']:.2f}, Serial={row['Serial_Correlation']:.2f}, diff={row['Difference']:.2f}, p={row['P_Value']:.4f}")

        print(f"\nCategories with equal performance: {len(equal_performance)}")
        for _, row in equal_performance.iterrows():
            print(f"  {row['Model']} - {row['Category']}: Batch={row['Batch_Correlation']:.2f}, Serial={row['Serial_Correlation']:.2f}")

    except Exception as e:
        print(f"Error comparing prompting methods: {str(e)}")

def main():
    """
    Run the full analysis pipeline from the command line.

    The pipeline:

    1. Parse the arguments.
    2. Create ``<output_dir>/<prompt_type>/``.
    3. Prepare, rank and profile every matching input file.
    4. If ``--compare_methods`` was given and more than one file produced a
       correlation, compare batch vs. serial prompting.
    """
    args = parse_arguments()

    # Every output of this run goes under <output_dir>/<prompt_type>/.
    ensure_dir(args.output_dir)
    prompt_output_dir = os.path.join(args.output_dir, args.prompt_type)
    ensure_dir(prompt_output_dir)

    input_files = find_input_files(args.prompt_type, args.model, args.temperature, args.prompting_method)
    if not input_files:
        print(f"No input files found for prompt type '{args.prompt_type}', model '{args.model}', temperature '{args.temperature}', method '{args.prompting_method}'.")
        return

    print(f"Found {len(input_files)} input files to process.")

    # (correlation, metadata, ranked_df) for each file that yielded a correlation.
    correlations_data = []

    for input_file in input_files:
        print(f"Processing file: {input_file}")

        prepared_df, metadata = prepare_dataset(input_file, prompt_output_dir)
        if prepared_df is None or metadata is None:
            # prepare_dataset has already printed the reason.
            continue

        file_prefix = os.path.splitext(os.path.basename(input_file))[0]
        correlation, ranked_df = calculate_rankings(prepared_df, prompt_output_dir, file_prefix, metadata)
        visualize_value_profile(prepared_df, prompt_output_dir, file_prefix, metadata)

        if correlation is not None and ranked_df is not None:
            correlations_data.append((correlation, metadata, ranked_df))

    if args.compare_methods and len(correlations_data) > 1:
        compare_prompting_methods(correlations_data, prompt_output_dir, args.model, args.prompt_type)

    print(f"Analysis complete. Results saved to {prompt_output_dir}")

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()